#from __future__ import division

import mpctools as mpc
import casadi 
import numpy as np
import time
from scipy import linalg
from numpy import random
import matplotlib.pyplot as plt
np.set_printoptions(precision=4) 

#from numpy import random
#random.seed(2) 

##################################################################################################
"   import new subsystem and measurement models, one requirement is that the model is normalized               "
"   w.r.t. time. That is, Delta = 1 denotes one sampling time chosen for the considered system model.          "
"   For example, in the ode models for WWTP, delta = 1 corresponds to 15 min. The size of the sampling time    "
"   should be carefully picked and then scale the model w.r.t. time correspondingly.                           "

from C_4CSTR_model_norm import ode_sub1_4cstr
from C_4CSTR_model_norm import measurement_sub_1

##################################################################################################

#random.seed(927) # Seed random number generator.

fullInformation = False # True for full information estimation, False for MHE.

# Sampling time, in hours (the same unit as Tsim). If the system model is
# already normalized w.r.t. time, pick Delta = 1. Here Delta = 1/120 h = 30 s.
Delta = 1/60/2
Tsim = 2 # Total simulation time (h)
# Round before int(): Tsim/Delta is a float quotient that could land just
# below an integer (e.g. 239.99999999999997), which int() would truncate.
Nsim = int(round(Tsim/Delta)) # Number of sampling intervals (240 for the values above)
tplot = np.arange(Nsim+1)*Delta # Time grid (h) of the Nsim+1 state samples

No_subsys = 1   # number of subsystems we have
Nt_sub = np.array([40])  # Estimation horizon length for each subsystem
Nu = 4  # number of manipulated inputs (the four heat inputs Q1..Q4)
Nx = 8  # number of states of the whole system
Nx_sub = np.array([8])  # Number of subsystem states, depending on how the subsystems are configured
Nx_total = sum(Nx_sub)

Ny_sub = np.array([2]) # Number of available measurements in subsystems, depending on how the subsystems are configured
Nw_sub = np.array(Nx_sub)  # one process-noise component per state
Nv_sub = np.array(Ny_sub)  # one measurement-noise component per measurement

Nu_sub = np.array([Nx_total + Nu]) # DO NOT NEED TO MODIFY number of shared variables for each local estimator

# Tuning parameters of MHE and process disturbance
sigma_v = 0.004 # Standard deviation of the measurements, 0.004
sigma_w = 0.005 # Standard deviation for the process noise, 0.005
sigma_p = 0.1  # Standard deviation for prior (used in arrival cost)

# Prior (arrival-cost) covariance for each subsystem: a diagonal matrix whose
# entries are all sigma_p**2. P_all holds one such matrix per subsystem.
T10 = np.ones((Nx_sub[0],))
P_sub = np.diag((sigma_p*T10)**2)
P_all = [P_sub]

### Scale the initial state of subsystem 1.
# Each state is mapped affinely using the bounds [lb*xs, ub*xs]:
#   x_scaled = (x - min_x) / (max_x - min_x)
x0_sub1 = np.array([2.78, 363, 2.58, 356, 2.6, 355, 2.6, 392]) # initial value (original units)
xs_sub1 = x0_sub1   # reference point that defines the scaling bounds
lb = 0.5
ub = 1.5
min_x_sub1 = lb*xs_sub1
max_x_sub1 = ub*xs_sub1
delta_x_sub1 = max_x_sub1 - min_x_sub1
x0_scale_sub1 = (x0_sub1 - min_x_sub1)/delta_x_sub1
x_0_sub1 = x0_scale_sub1
# Simulator of the "true" plant (a single subsystem). It is stepped once per
# sampling interval Delta when generating the state trajectory below.
MHE_sub1_casadi = mpc.DiscreteSimulator(ode_sub1_4cstr, Delta, [Nx_sub[0],Nu_sub[0],Nw_sub[0]], ["x","u","w"])
# Model used inside the MHE. Because rk4=False, no RK4 discretization is done
# here: F_sub1 wraps the ODE right-hand side as-is. The MHE is built with
# N["c"] = 4, presumably so nmhe discretizes it by collocation — confirm
# against the mpctools documentation.
F_sub1 = mpc.getCasadiFunc(ode_sub1_4cstr,[Nx_sub[0],Nu_sub[0],Nw_sub[0]],["x","u","w"],"F_sub1",rk4=False,Delta=Delta)#,M=2)
# Measurement model y = h(x).
H_sub1 = mpc.getCasadiFunc(measurement_sub_1,[Nx_sub[0]],["x"],"H_sub1")

# Define stage costs of the MHE.
# Process noise on state i is generated with standard deviation
# delta_x_sub1[i]*sigma_w; its inverse goes on the diagonal of the weight matrix.
Sigma_wIN = np.diag([(delta_x_sub1[k]*sigma_w)**-1 for k in range(Nw_sub[0])])
    
def lfunc_sub1(w,v):
    """Stage cost of the MHE: weighted squared process and measurement noise.

    w -- process-noise vector of length Nw_sub[0]
    v -- measurement-noise vector of length Nv_sub[0]

    Each w[i] is divided by its standard deviation delta_x_sub1[i]*sigma_w
    (the diagonal entries of Sigma_wIN), and each v[j] by sigma_v. The cost
    is therefore the sum of the squared normalized noises.
    """
    # Sigma_wIN is diagonal, so the matrix-vector product scales w[i] by
    # Sigma_wIN[i,i]. This avoids an elementwise product of an (Nw, Nw) numpy
    # array with the Nw-vector w, which only works via broadcasting.
    return casadi.sumsqr(mpc.mtimes(Sigma_wIN, w)) + sigma_v**-2*casadi.sumsqr(v)


l_sub1 = mpc.getCasadiFunc(lfunc_sub1,[Nw_sub[0],Nv_sub[0]],["w","v"],"l_sub1")

def lxfunc_sub1(x):
    """Arrival (prior) cost x' * inv(P) * x with P = P_all[0]."""
    P_inv = linalg.inv(P_all[0])
    return mpc.mtimes(x.T, P_inv, x)
lx_sub1 = mpc.getCasadiFunc(lxfunc_sub1,[Nx_sub[0]],["x"],"lx_sub1")


# Generate subsystem disturbances and measurement noise.
# The process-noise column for state i is scaled by delta_x_sub1[i]. The w
# draw happens before the v draw, which fixes the order of the random stream.
w_sub1 = sigma_w*random.randn(Nsim,Nw_sub[0])*delta_x_sub1
v_sub1 = sigma_v*random.randn(Nsim,Nv_sub[0])

# Storage for the simulated (scaled) state trajectory and the measurements.
xsim_sub1 = np.zeros((Nsim+1,Nx_sub[0]))
xsim_sub1[0,:] = x_0_sub1   # scaled initial state of the actual process
yclean_sub1 = np.zeros((Nsim, Ny_sub[0]))
ysim_sub1 = np.zeros((Nsim, Ny_sub[0]))

# Constant heating inputs to the four tanks.
Q1 = 1*(10**4)
Q2 = 2*(10**4)
Q3 = 2.5*(10**4)
Q4 = 1*(10**4)

# Shared information used for trajectory generation. Columns 0-7 hold the
# actual states (filled in during the simulation); columns 8-11 hold Q1..Q4.
usim_state_generate = np.zeros((Nsim+1,Nu_sub[0]))
usim_state_generate[0,0:8] = x0_sub1
usim_state_generate[:,8:12] = [Q1, Q2, Q3, Q4]

#===================================================================================================================#

# Actual (unscaled) state trajectory: x = x_scaled*delta_x + min_x.
x_actual_sub1 = np.zeros((Nsim+1,Nx_sub[0]))
x_actual_sub1[0,:] = xsim_sub1[0,:]*delta_x_sub1 + min_x_sub1

# Generate the process trajectory in scaled coordinates and the noisy measurements.
for t in range(Nsim):
    yclean_sub1[t,:] = measurement_sub_1(xsim_sub1[t]) # Get zero-noise measurement.
    ysim_sub1[t,:] = yclean_sub1[t,:] + v_sub1[t,:] # Add noise to measurement.
    xsim_sub1[t+1,:] = MHE_sub1_casadi.sim(xsim_sub1[t,:],usim_state_generate[t,:],w_sub1[t,:])

    # Recover the actual state at the new sampling time.
    x_actual_sub1[t+1,:] = xsim_sub1[t+1,:]*delta_x_sub1 + min_x_sub1
    # The next simulator step receives the actual states as shared variables
    # (columns 0-7 of its input vector).
    usim_state_generate[t+1,0:8] = x_actual_sub1[t+1,:]


##===================================================================================================================#
# Columns 0-7 of usim_state_generate now hold the actual state trajectory for t = 0..Nsim.
x_actual_entire_sys = usim_state_generate[:,0:8]

# Optional: replay previously saved data instead of the freshly simulated data.
#x_actual_entire_sys = np.loadtxt('xp_true.txt')
#x_actual_sub1 = x_actual_entire_sys 
#ysim_sub1 = np.loadtxt('y.txt')
#usim_state_generate = np.loadtxt('usim.txt')
np.savetxt("x_actual_entire_sys.txt",x_actual_entire_sys)
#===================================================================================================================#


# Initial guess (prior) for the MHE: 10% above the true scaled initial state.
x0bar_sub1 = 1.1*x_0_sub1
xhat_sub1 = np.zeros((Nsim,Nx_sub[0]))
xhat_sub1[0,:] = x0bar_sub1
guess_sub1 = {}

# Options handed to solver.initialize() (presumably forwarded to IPOPT).
solveroptions = {'linear_solver': 'mumps'}

totaltime = -time.time()

# Both arrays use the same layout: columns 0-7 hold state estimates
# (original units) and columns 8-11 hold the heat inputs Q1..Q4.
#   u_mhe_all_subs -- final estimates, updated once per time step
#   x_est_stage    -- latest estimates, fed back to the MHE as its input "u"
u_mhe_all_subs = np.zeros((Nsim+1,Nu_sub[0]))
u_mhe_all_subs[:,8:12] = usim_state_generate[:,8:12]
u_mhe_all_subs[0,0:8] = x0bar_sub1

x_est_stage = np.zeros((Nsim+1,Nu_sub[0]))
x_est_stage[:,8:12] = usim_state_generate[:,8:12]
x_est_stage[0,0:8] = u_mhe_all_subs[0,0:8]
########################################################################################################

# State estimates over time, converted back to original (unscaled) units.
x_actual_hat_sub1 = np.zeros((Nsim,Nx_sub[0]))

# Solve one MHE problem per sampling time.
for t in range(Nsim):

    # Estimation window. Full-information estimation keeps all data since
    # t = 0; MHE keeps at most the last Nt_sub[0] intervals.
    N_sub1 = {"x":Nx_sub[0], "y":Ny_sub[0], "u":Nu_sub[0], "c":4}
    if fullInformation:
        N_sub1["t"] = t
        tmin_sub1 = 0
    else:
        N_sub1["t"] = min(t,Nt_sub[0])
        tmin_sub1 = max(0,t - Nt_sub[0])
    tmax_sub1 = t+1
    n_win = N_sub1["t"]   # number of intervals in the current window
    # Lower bound of zero on every (scaled) state in the window.
    lb_sub1 = {"x":np.zeros((n_win + 1,Nx_sub[0]))}

    ## add more dictionaries here if there are more estimators

    buildtime_mhe1 = -time.time()
    contargs_sub1 = dict(
        u=x_est_stage[tmin_sub1:t,:],
        y=ysim_sub1[tmin_sub1:tmax_sub1,:], l=l_sub1, N=N_sub1, lx=lx_sub1,
        x0bar = x0bar_sub1, verbosity=0, guess = guess_sub1,
        lb=lb_sub1
        )
    solver_sub1 = mpc.nmhe(f=F_sub1, h=H_sub1, Delta=Delta,**contargs_sub1)
    buildtime_mhe1 += time.time()

    solvetime_mhe1 = -time.time()
    solver_sub1.initialize(solveroptions=solveroptions)
    sol_sub1 = mpc.callSolver(solver_sub1)
    solvetime_mhe1 += time.time()

    print ("%3d (%5.3g s build, %5.3g s solve_mhe1): %s"
           % (t, buildtime_mhe1, solvetime_mhe1, sol_sub1["status"]))
#==============================================================================#
    # The last n_win+1 rows of sol_sub1["x"] are the scaled estimates for
    # times t-n_win .. t. Store them, then convert that span to original units.
    xhat_sub1[t-n_win:t+1,:] = sol_sub1["x"][-(n_win+1):]
    x_actual_hat_sub1[t-n_win:t+1,:] = xhat_sub1[t-n_win:t+1,:]*delta_x_sub1 + min_x_sub1

    # Save stuff to use as a guess. Cycle the guess.
    guess_sub1 = {}
    for k in set(["x","w","v"]).intersection(sol_sub1.keys()):
        guess_sub1[k] = sol_sub1[k].copy()

    # Once the window is full, it slides forward by one step at the next time
    # point. Drop the oldest guess entry, and move the prior x0bar to the
    # estimate at the new window start (index 1 of the current window).
    # Both updates must use the same condition; previously the prior update
    # used t > Nt and lagged the shift by one step.
    if not fullInformation and t + 1 > Nt_sub[0]:    # This is MHE, not full information estimator
        x0bar_sub1 = guess_sub1["x"][1,0:Nx_sub[0]]
        for k in guess_sub1.keys():
            guess_sub1[k] = guess_sub1[k][1:,...] # Get rid of oldest measurement.

    # Add final guess state for new time point.
    for k in guess_sub1.keys():
        guess_sub1[k] = np.concatenate((guess_sub1[k],guess_sub1[k][-1:,...]))

    # The newest estimate becomes the shared input for the next MHE solve.
    # u_mhe_all_subs keeps only the latest estimate for each time t.
    x_est_stage[t,0:8] = x_actual_hat_sub1[t,:]
    u_mhe_all_subs[t,0:8] = x_est_stage[t,0:8]

totaltime += time.time()

print("Simulation took %.5g s." % totaltime)


### RMSE
# Final estimates (original units) for t = 0..Nsim. Row Nsim is never filled
# by the MHE loop and stays zero; it is excluded from the error below.
x_estimate_entire = u_mhe_all_subs[:,0:Nx]

np.savetxt("x_esti_cent.txt",x_estimate_entire)

# Compare estimates with the true states at t = 0..Nsim-1.
x_true = x_actual_sub1[:-1,0:Nx].copy()
x_esti = x_estimate_entire[:-1,0:Nx].copy()

# For each time step: sum over states of the squared relative estimation
# error (not divided by Nx).
RMSE_x_mhe = np.sum(((x_esti - x_true)/x_true)**2, axis=1)
RMSE_xp_mhe = np.zeros(Nsim)   # not computed; kept for the commented-out plot below

sum_RMSE_x_mhe = RMSE_x_mhe.sum()

# Square root of the time-averaged summed squared relative error.
ave_RMSE_x_mhe = np.sqrt( sum_RMSE_x_mhe/(Nsim))
print("ave_RMSE_x_mhe =",ave_RMSE_x_mhe)

### plt RMSE
#fig9 = plt.figure(9)
#x1ax = plt.subplot(211)
#x1ax.plot(tplot[0:Nsim],RMSE_x_mhe,'k',label='Case 1')
#plt.legend(loc='upper right')  #legend location
#plt.setp(x1ax.get_xticklabels(), visible=False)
#plt.ylabel('$RMSE_x$')
##x2ax = plt.subplot(212,sharex=x1ax)
##plt.plot(tplot[0:Nsim],RMSE_p_mhe,'k')
##plt.setp(x1ax.get_xticklabels(), visible=False)
##plt.ylabel('$RMSE_p$')
#x3ax = plt.subplot(212,sharex=x1ax)
#plt.plot(tplot[0:Nsim],RMSE_xp_mhe,'k')
#plt.xlim(0.0,Delta*(Nsim))
#plt.xlabel('Time(h)')
#plt.ylabel('$RMSE_{xp}$')
#plt.savefig('fig9.pdf', bbox_inches='tight', format='pdf')#, dpi=1000)


##================================================================================================##

## Plot actual vs. estimated states, one figure per pair of CSTRs.
## The four states of each figure are drawn together on a 2x2 grid.
def _plot_cstr_pair(fig_num, cols, ylabels, filename, t, x_true, x_est):
    """Draw actual (red) vs. estimated (green, dashed) states on a 2x2 grid and save as EPS.

    fig_num  -- matplotlib figure number
    cols     -- four state-column indices, drawn in subplots 221, 222, 223, 224
    ylabels  -- y-axis labels for those four subplots
    filename -- path of the EPS file to write
    t, x_true, x_est -- time grid and trajectories; the final sample is not plotted
    """
    plt.figure(fig_num)
    for k, (col, ylabel) in enumerate(zip(cols, ylabels)):
        ax = plt.subplot(2, 2, k + 1)
        # Only the first subplot carries labelled lines and therefore a legend.
        if k == 0:
            true_kw, est_kw = {"label": "actual state"}, {"label": "state estimate"}
        else:
            true_kw, est_kw = {}, {}
        ax.plot(t[:-1], x_true[:-1, col], color="red", **true_kw)
        ax.plot(t[:-1], x_est[:-1, col], color="green", linestyle="--", **est_kw)
        ax.set_ylabel(ylabel)
        if k == 0:
            ax.legend(loc='lower right', fontsize=8)
        if k >= 2:   # bottom row
            ax.set_xlabel('Time(h)')
    plt.tight_layout()   # adjust subplot spacing to minimize overlap
    plt.savefig(filename, bbox_inches='tight', format='eps')


_plot_cstr_pair(1, (0, 2, 1, 3), ("$C_{A1}$", "$C_{A2}$", "$T_1$", "$T_2$"),
                'CT12.eps', tplot, x_actual_sub1, x_estimate_entire)
_plot_cstr_pair(2, (4, 6, 5, 7), ("$C_{A3}$", "$C_{A4}$", "$T_3$", "$T_4$"),
                'CT34.eps', tplot, x_actual_sub1, x_estimate_entire)

#[fig, ax] = plt.subplots(nrows=2)
#x1ax = ax[0]
#x2ax = ax[1]
#    
#x1ax.plot(tplot[:-1], x_actual_sub1[:-1,0], color="red", label="actual")
#x1ax.plot(tplot[:-1], x_estimate_entire[:-1,0], marker="o", color= "green", markersize=1.5,
#    markeredgecolor="green", linestyle="", label="estimate" )
#mpc.plots.zoomaxis(x1ax, xscale=1.05, yscale=1.05)
#x1ax.set_ylabel("1st state in sub1")
#x1ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))
#
#x2ax.plot(tplot[:-1], x_actual_sub1[:-1,1], color="red", label="actual")
#x2ax.plot(tplot[:-1], x_estimate_entire[:-1,1], marker="o", color= "green", markersize=1.5,
#     markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x2ax, xscale=1.05, yscale=1.05)
#x2ax.set_ylabel("2nd state in sub1")
#x2ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))
#    
#    
#[fig, ax] = plt.subplots(nrows=2)
#x3ax = ax[0]
#x4ax = ax[1] 
#    
#x3ax.plot(tplot[:-1], x_actual_sub1[:-1,2], color="red", label="actual")
#x3ax.plot(tplot[:-1], x_estimate_entire[:-1,2], marker="o", color= "green", markersize=1.5,
#    markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x3ax, xscale=1.05, yscale=1.05)
#x3ax.set_ylabel("3nd state in sub1")
#x3ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))
#
#x4ax.plot(tplot[:-1], x_actual_sub1[:-1,3], color="red", label="actual")
#x4ax.plot(tplot[:-1], x_estimate_entire[:-1,3], marker="o", color= "green", markersize=1.5,
#     markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x4ax, xscale=1.05, yscale=1.05)
#x4ax.set_ylabel("4nd state in sub1")
#x4ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))
#    
#    
#[fig, ax] = plt.subplots(nrows=2) 
#x5ax = ax[0]
#x6ax = ax[1]
#    
#x5ax.plot(tplot[:-1], x_actual_sub1[:-1,4], color="red", label="actual")
#x5ax.plot(tplot[:-1], x_estimate_entire[:-1,4], marker="o", color= "green", markersize=1.5,
#         markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x5ax, xscale=1.05, yscale=1.05)
#x5ax.set_ylabel("5nd state in sub1")
#x5ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))
#
#x6ax.plot(tplot[:-1], x_actual_sub1[:-1,5], color="red", label="actual")
#x6ax.plot(tplot[:-1], x_estimate_entire[:-1,5], marker="o", color= "green", markersize=1.5,
#         markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x6ax, xscale=1.05, yscale=1.05)
#x6ax.set_ylabel("6nd state in sub1")
#x6ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))    
#    
#    
#[fig, ax] = plt.subplots(nrows=2)
#x7ax = ax[0]
#x8ax = ax[1]
#    
#x7ax.plot(tplot[:-1], x_actual_sub1[:-1,6], color="red", label="actual")
#x7ax.plot(tplot[:-1], x_estimate_entire[:-1,6], marker="o", color= "green", markersize=1.5,
#         markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x7ax, xscale=1.05, yscale=1.05)
#x7ax.set_ylabel("7nd state in sub1")
#x7ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))   
#
#x8ax.plot(tplot[:-1], x_actual_sub1[:-1,7], color="red", label="actual")
#x8ax.plot(tplot[:-1], x_estimate_entire[:-1,7], marker="o", color= "green", markersize=1.5,
#         markeredgecolor="green", linestyle="", label="estimate")
#mpc.plots.zoomaxis(x8ax, xscale=1.05, yscale=1.05)
#x8ax.set_ylabel("8nd state in sub1")
#x8ax.legend(loc="center left", bbox_to_anchor=(1.01, 0.5))      
#    
